import random
import argparse
import numpy as np
import os
import pandas as pd
import copy
from scipy import optimize

from architectures import get_architecture
from datasets import get_dataset, DATASETS, get_num_classes
import torch

# Command-line configuration. Both options are used unconditionally below
# (file names and the radius file), so they are required.
parser = argparse.ArgumentParser(description="config")
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    help="dataset name passed to get_dataset; also used to name the cached "
         "distance file and the result CSV",
)
parser.add_argument(
    "--file",
    type=str,
    required=True,
    help="tab-separated file with a header row; fields 3 and 4 of each "
         "data row are read",
)
args = parser.parse_args()

# load distance array: pairwise Euclidean distances between the first 10000
# test samples, cached on disk as distance_<dataset>.npy
distance_path = "distance_" + args.dataset + ".npy"
if not os.path.exists(distance_path):
    dataset = get_dataset(args.dataset, 'test')
    num_samples = 10000
    # Load and flatten every sample once, instead of reloading it per pair.
    # NOTE(review): assumes dataset[i] returns (tensor, label), as the original
    # unpacking implied — confirm against datasets.get_dataset.
    samples = np.stack(
        [dataset[i][0].numpy().reshape(-1) for i in range(num_samples)]
    )
    dis = np.zeros((num_samples, num_samples))
    for i in range(num_samples - 1):
        # Distances from sample i to every later sample, mirrored so the
        # matrix is symmetric with a zero diagonal.
        row = np.sqrt(np.sum(np.square(samples[i + 1:] - samples[i]), axis=1))
        dis[i, i + 1:] = row
        dis[i + 1:, i] = row
    del samples  # release the flattened copy before the large work below
    np.save(distance_path, dis)
else:
    dis = np.load(distance_path)

# load radius: one row per sample, filled from fields 3 and 4 of each
# tab-separated line of args.file (the first line is a header and is skipped).
# NOTE(review): field 3 is read as the radius; field 4 is presumably a
# correct/incorrect flag (only samples with a positive value are scored later).
arr = np.zeros((10000, 2))

with open(args.file, "r") as f:
    for row_no, raw in enumerate(f.readlines()[1:]):
        fields = raw.split("\t")
        arr[row_no, 0] = float(fields[3])
        arr[row_no, 1] = float(fields[4])

res = np.zeros(10000)

# graph degree: samples i and j are adjacent when the sum of their radii
# exceeds the distance between them; degree[k] counts k's neighbours.
# The degree array depends on both the dataset and the radius file, so both
# are part of the cache name.
degree_path = (
    "degree_" + args.dataset + "_" + os.path.basename(args.file) + ".npy"
)
radius = arr[:, 0]
degree = np.zeros(10000)
for i in range(10000):
    # Compare sample i against every earlier sample j < i in one step;
    # each overlapping pair adds one to both endpoints.
    overlap = radius[i] + radius[:i] > dis[:i, i]
    degree[i] += np.count_nonzero(overlap)
    degree[:i] += overlap
    if i % 1000 == 0:
        print(i)
np.save(degree_path, degree)

# Visit order for the greedy pass: sample indices sorted by ascending degree.
# The sort is stable, so samples with equal degree keep their index order.
degree = list(degree)
idx = sorted(range(len(degree)), key=degree.__getitem__)
sorted_degree = [(k, degree[k]) for k in idx]

# Result table: one row per radius file, columns are the certified-accuracy
# thresholds followed by "ACR". Rows accumulate across runs in result_<dataset>.csv.
result_path = "result_" + args.dataset + ".csv"
if not os.path.exists(result_path):
    df = pd.DataFrame(columns=[0, 0.25, 0.5, 0.75, 1, 1.25, 1.50, 1.75, 2, 2.25, "ACR"])
else:
    # to_csv writes the row labels as the first column; read them back as the
    # index so they do not become an extra "Unnamed: 0" data column.
    df = pd.read_csv(result_path, index_col=0)

def test_acc_right(input_arr, df, name, Max=2.5, items=10000, correct=None):
    """Write certified-accuracy statistics for one set of radii as row ``name`` of ``df``.

    For each threshold ``t`` in ``0, 0.25, 0.5, ...`` strictly below ``Max``,
    the stored value is the number of counted samples whose radius is greater
    than ``t``, divided by ``items``. A sample among the first ``items`` is
    counted when its ``correct`` flag is > 0. The last value (column "ACR") is
    the sum of the counted radii divided by ``items``.

    Args:
        input_arr: per-sample radii; only the first ``items`` entries are read.
        df: DataFrame whose columns are the thresholds followed by "ACR";
            modified in place.
        name: row label to write.
        Max: exclusive upper bound of the thresholds.
        items: number of samples considered; also the denominator of every entry.
        correct: per-sample flags. Defaults to column 1 of the module-level ``arr``.
    """
    if correct is None:
        correct = arr[:, 1]
    radii = np.asarray(input_arr, dtype=float)[:items]
    counted = radii[np.asarray(correct)[:items] > 0]

    # Thresholds on an exact integer grid (2500 / 10000 == 0.25) instead of
    # testing `t % 0.25 == 0` on floats for every 1e-4 step.
    thresholds = [k / 10000 for k in range(0, int(Max * 10000), 2500)]
    row = [np.count_nonzero(counted > t) / items for t in thresholds]
    row.append(counted.sum() / items)
    df.loc[name] = row


# The helper's name starts with "test_"; stop pytest from collecting it as a test.
test_acc_right.__test__ = False

# Greedy pass in ascending-degree order. Each sample starts from its own
# radius. If it overlaps a sample processed before it, its radius is cut back
# to the remaining gap, or to 0 when that gap is negative. Samples with no
# neighbours keep their radius unchanged.
for position, node in enumerate(idx):
    res[node] = arr[node, 0]
    if degree[node] == 0:
        continue

    for other in idx[:position]:
        gap = dis[other, node]
        if res[node] + res[other] > gap:
            res[node] = max(gap - res[other], 0)

test_acc_right(input_arr=res, df=df, name=args.file)
df.to_csv("result_" + args.dataset + ".csv")
